"""Plotting utilities for the fractal-pivot calibration pipeline.

This module centralises all matplotlib plotting logic used by
``main.py``.  It defines a single helper function for producing
combined scatter and curve plots of the raw box-counting dimension
estimates alongside the fitted logistic model.  The resulting plot
includes an annotation box with the fitted parameters and goodness of
fit and is saved as a PNG.  Importing this module selects matplotlib's
non-interactive ``Agg`` backend.
"""

from __future__ import annotations

import matplotlib
matplotlib.use('Agg')  # Use non‑interactive backend for headless environments
import matplotlib.pyplot as plt
from typing import Iterable

def plot_pivot(n_vals: Iterable[float],
               D_vals: Iterable[float],
               pred: Iterable[float],
               k: float,
               n0: float,
               r2: float,
               dataset_name: str,
               output_path: str) -> None:
    """Create and save a pivot plot.

    The raw dimension estimates are drawn as markers.  The logistic model
    is drawn as a line through the predictions, ordered by scale so the
    curve runs left to right regardless of input order.  The fitted
    parameters and the coefficient of determination are shown in a text
    box in the upper-left corner of the axes.

    Parameters
    ----------
    n_vals : Iterable[float]
        Scale indices at which the raw dimension estimates were
        computed.  Any iterable (list, tuple, NumPy array, pandas
        Series, generator) is accepted; it is consumed exactly once.
    D_vals : Iterable[float]
        Observed dimension values corresponding to ``n_vals``.
    pred : Iterable[float]
        Predicted dimension values from the fitted logistic model,
        aligned element-wise with ``n_vals``.
    k : float
        Fitted steepness parameter.
    n0 : float
        Fitted pivot parameter.
    r2 : float
        Coefficient of determination.
    dataset_name : str
        Name of the dataset (used for the plot title).
    output_path : str
        Path on disk where the PNG should be written.

    Raises
    ------
    ValueError
        If ``D_vals`` or ``pred`` does not contain the same number of
        elements as ``n_vals``.
    """
    # Materialise each input once.  The values are needed both for
    # plotting and for building the sorted curve, so one-shot iterables
    # (generators) must not be consumed twice.  Converting to a list also
    # makes positional access independent of any custom index, e.g. on a
    # filtered pandas Series.
    n_list = list(n_vals)
    d_list = list(D_vals)
    pred_list = list(pred)
    if len(d_list) != len(n_list) or len(pred_list) != len(n_list):
        raise ValueError(
            "n_vals, D_vals and pred must have the same length "
            f"(got {len(n_list)}, {len(d_list)}, {len(pred_list)})"
        )

    # Order the (scale, prediction) pairs by scale only.  sorted() is
    # stable, so points sharing the same scale keep their input order.
    curve = sorted(zip(n_list, pred_list), key=lambda pair: pair[0])
    sorted_n = [n for n, _ in curve]
    sorted_pred = [p for _, p in curve]

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.plot(n_list, d_list, 'o', label='D_raw')
        ax.plot(sorted_n, sorted_pred, '-', label='Logistic fit')
        ax.set_xlabel('Scale n')
        ax.set_ylabel('Estimated D(n)')
        ax.set_title(f"Fractal Pivot Fit: {dataset_name}")
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        # Plain-text annotation (no LaTeX) to avoid font issues.
        annotation = f"k = {k:.3f}\nn0 = {n0:.3f}\nR2 = {r2:.3f}"
        # Axes-relative coordinates pin the box to the upper-left corner
        # independently of the data range.
        ax.text(0.05, 0.95, annotation, transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
    finally:
        # Release this specific figure even if drawing or saving fails,
        # so repeated calls do not accumulate open figures in pyplot.
        plt.close(fig)